#!/usr/bin/env python
# -*- encoding: utf-8 -*-
"""
@File    :   AccuracyMonitorPytorch.py    
@Contact :   ruaqy@qq.com
@License :   (C)Copyright 2022-~, GPL 2.0

@Modify Time      @Author    @Version    @Description
------------      -------    --------    -----------
2022/5/16 23:30   rqy        1.0         None
"""

import torch


def accuracy(net, dts, device):
    """Compute the top-1 classification accuracy of ``net`` over ``dts``.

    Args:
        net: Model or callable mapping an input batch to class scores; the
            predicted class is ``argmax(dim=1)`` of its output. If ``net`` is a
            ``torch.nn.Module`` it is put in evaluation mode for the duration
            of the call (so dropout / batch-norm behave as at inference time),
            and every submodule's original train/eval flag is restored
            afterwards, even if an exception is raised.
        dts: Iterable yielding ``(x, y)`` batches (e.g. a ``DataLoader``);
            ``y`` holds the target class indices, one per sample along dim 0.
        device: Device that inputs and labels are moved to before the forward
            pass / comparison.

    Returns:
        0-dim floating-point tensor: number of correct predictions divided by
        the total number of samples.

    Raises:
        ZeroDivisionError: if ``dts`` yields no samples.
    """
    # Save each submodule's own flag (not just the root's) so a model with a
    # deliberately mixed train/eval state is restored exactly.
    saved_modes = []
    if isinstance(net, torch.nn.Module):
        saved_modes = [(module, module.training) for module in net.modules()]
        net.eval()

    try:
        correct, total = 0, 0
        with torch.no_grad():
            for x, y in dts:
                preds = net(x.to(device)).argmax(dim=1)
                correct += (preds == y.to(device)).sum()
                total += y.shape[0]
    finally:
        for module, was_training in saved_modes:
            module.training = was_training

    if total == 0:
        raise ZeroDivisionError("accuracy() got a dataset with no samples")
    # Kept as a tensor (no .item()) to preserve the original return type.
    return correct / total
